from PIL import Image
import tensorflow.keras.backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
import tensorflow as tf
import glob, pickle
import os
import random
import time
import numpy as np
import matplotlib.pyplot as plt

NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

CAPTCHA_CHARSET = NUMBER   # charset used to generate the captchas (digits only)
CAPTCHA_LEN = 4            # number of characters per captcha
CAPTCHA_HEIGHT = 60        # captcha image height
CAPTCHA_WIDTH = 160        # captcha image width

TRAIN_DATA_DIR = './images/'     # path to the captcha training set
TEST_DATA_DIR = './validation/'  # path to the captcha validation set

BATCH_SIZE = 100    # number of samples per training batch
EPOCHS = 40         # number of training epochs
LEARN_RATE = 0.1    # learning rate
# OPT = Adam(learning_rate=LEARN_RATE, amsgrad=True)  # optimize with the Adam algorithm
# OPT = 'RMSprop'
# OPT = 'Nadam'
OPT = Nadam(learning_rate=0.0002)
# OPT = SGD(learning_rate=LEARN_RATE, decay=LEARN_RATE/EPOCHS, momentum=0.9, nesterov=True)
# LOSS = 'binary_crossentropy'  # binary cross-entropy treats the vector components as independent
LOSS = 'categorical_crossentropy'

# Model file directory and extension
MODEL_DIR = './model/train_demo/'
MODEL_FORMAT = '.h5'
# Training-history file directory and extension
HISTORY_DIR = './history/train_demo/'
HISTORY_FORMAT = '.history'

# Filename template for model and history files
filename_str = "{}captcha_{}_bs_{}_epochs_{}{}"

# # Model architecture visualization file
# MODEL_VIS_FILE = 'captcha_classification.png'
# Model file
MODEL_FILE = filename_str.format(MODEL_DIR, LOSS, str(BATCH_SIZE),
                                 str(EPOCHS), MODEL_FORMAT)
# Training-history file
HISTORY_FILE = filename_str.format(HISTORY_DIR, LOSS, str(BATCH_SIZE),
                                   str(EPOCHS), HISTORY_FORMAT)
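# e.g. MODEL_FILE -> './model/train_demo/captcha_categorical_crossentropy_bs_100_epochs_40.h5'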

# Convert an RGB image (or batch of images) to grayscale using ITU-R BT.601 luma weights
def rgb2gray(image):
    return np.dot(image[..., :3], [0.299, 0.587, 0.114])
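# e.g. a batch of shape (N, 60, 160, 3) becomes (N, 60, 160)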

# One-hot encode a captcha string into a flat vector
def text2vec(text, length=CAPTCHA_LEN, charset=CAPTCHA_CHARSET):
    text_len = len(text)
    # Validate the captcha length
    if text_len != length:
        raise ValueError(
            "Input is {} characters long, expected a captcha of length {}".format(text_len, length))
    vec = np.zeros(length * len(charset))
    for i in range(length):
        # Set the hot bit for character i inside its own 10-way block
        vec[charset.index(text[i]) + i * len(charset)] = 1
    return vec
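# e.g. text2vec('1234') -> a length-40 vector with ones at indices 1, 12, 23, 34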


# Decode a one-hot vector back into a captcha string
def vec2text(vector):
    if not isinstance(vector, np.ndarray):
        vector = np.asarray(vector)
    vector = np.reshape(vector, [CAPTCHA_LEN, -1])
    text = ''
    for item in vector:
        # Pick the most probable character in each 10-way block
        text += CAPTCHA_CHARSET[np.argmax(item)]
    return text
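# Round trip: vec2text(text2vec('1234')) == '1234'; this also works on the
# model's (1, 40) prediction, since np.reshape absorbs the batch dimension.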


# Reshape a batch to match the Keras image data format (channels first/last)
def fit_keras_channels(batch, rows=CAPTCHA_HEIGHT, cols=CAPTCHA_WIDTH):
    if K.image_data_format() == 'channels_first':
        batch = batch.reshape(batch.shape[0], 1, rows, cols)
        input_shape = (1, rows, cols)
    else:
        batch = batch.reshape(batch.shape[0], rows, cols, 1)
        input_shape = (rows, cols, 1)
    return batch, input_shape
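# e.g. a (N, 60, 160) grayscale batch becomes (N, 60, 160, 1) under the
# default 'channels_last' format, with input_shape = (60, 160, 1)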

if __name__ == '__main__':
    # Read the training set
    X_train, Y_train = [], []
    # Collect all '.jpg' files with glob.glob
    filename = glob.glob(TRAIN_DATA_DIR + '*.jpg')
    random.seed(time.time())
    random.shuffle(filename)
    for file in filename:
        X_train.append(np.array(Image.open(file)))
        # The label is the file's base name without the extension
        Y_train.append(os.path.splitext(os.path.basename(file))[0])
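    # e.g. './images/1234.jpg' -> '1234' (each file is assumed to be named
    # after the captcha text it contains)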

    # Preprocess the training images
    # Convert X_train to a float32 NumPy array
    X_train = np.array(X_train, dtype=np.float32)
    # Convert the RGB images to grayscale
    X_train = rgb2gray(X_train)
    # Normalize pixel values to [0, 1]
    X_train = X_train / 255
    # Adapt to the Keras image data format
    X_train, input_shape = fit_keras_channels(X_train)

    print(X_train.shape, type(X_train))
    print(input_shape)


    # One-hot encode the training labels
    for i in range(len(Y_train)):
        Y_train[i] = text2vec(Y_train[i])
    Y_train = np.asarray(Y_train)

    print(Y_train.shape, type(Y_train))


    # Read the validation set and process its images and labels
    X_test, Y_test = [], []
    filename = glob.glob(TEST_DATA_DIR + '*.jpg')
    random.seed(time.time())
    random.shuffle(filename)
    for file in filename:
        X_test.append(np.array(Image.open(file)))
        # The label is the file's base name without the extension
        Y_test.append(os.path.splitext(os.path.basename(file))[0])
    # Process the images
    X_test = np.array(X_test, dtype=np.float32)
    X_test = rgb2gray(X_test) / 255
    X_test, _ = fit_keras_channels(X_test)
    # One-hot encode the labels
    for i in range(len(Y_test)):
        Y_test[i] = text2vec(Y_test[i])
    Y_test = np.asarray(Y_test)

    print(X_test.shape)
    print(Y_test.shape)

    # Build a VGG16-style model
    # Input layer
    with tf.name_scope('inputs'):
        inputs = Input(shape=input_shape, name='inputs')

    # First convolutional block
    with tf.name_scope('con1'):
        conv1 = Conv2D(64, (3,3), name='conv1',padding='same', kernel_initializer='he_uniform')(inputs)
        bn1 = BatchNormalization()(conv1)
        act1 = Activation('relu')(bn1)
        # drop1 = Dropout(0.3)(act1)

        conv2 = Conv2D(64, (3, 3), name='conv2',padding='same', kernel_initializer='he_uniform')(act1)
        bn2 = BatchNormalization()(conv2)
        act2 = Activation('relu')(bn2)

        pool1 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool1')(act2)

    # Second convolutional block
    with tf.name_scope('con2'):
        conv3 = Conv2D(128, (3, 3), name='conv3',padding='same', kernel_initializer='he_uniform')(pool1)
        bn3 = BatchNormalization()(conv3)
        act3 = Activation('relu')(bn3)
        # drop2 = Dropout(0.4)(act3)

        conv4 = Conv2D(128, (3, 3), name='conv4',padding='same', kernel_initializer='he_uniform')(act3)
        bn4 = BatchNormalization()(conv4)
        act4 = Activation('relu')(bn4)

        pool2 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool2')(act4)

    # Third convolutional block
    with tf.name_scope('con3'):
        conv5 = Conv2D(256, (3,3), name='conv5',padding='same', kernel_initializer='he_uniform')(pool2)
        bn5 = BatchNormalization()(conv5)
        act5 = Activation('relu')(bn5)
        # drop3 = Dropout(0.4)(act5)

        conv6 = Conv2D(256, (3, 3), name='conv6',padding='same', kernel_initializer='he_uniform')(act5)
        bn6 = BatchNormalization()(conv6)
        act6 = Activation('relu')(bn6)
        # drop4 = Dropout(0.4)(act6)

        conv7 = Conv2D(256, (3, 3), name='conv7',padding='same', kernel_initializer='he_uniform')(act6)
        bn7 = BatchNormalization()(conv7)
        act7 = Activation('relu')(bn7)

        pool3 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool3')(act7)

    # Fourth convolutional block
    with tf.name_scope('con4'):
        conv8 = Conv2D(512, (3, 3), name='conv8',padding='same', kernel_initializer='he_uniform')(pool3)
        bn8 = BatchNormalization()(conv8)
        act8 = Activation('relu')(bn8)
        # drop5 = Dropout(0.4)(act8)

        conv9 = Conv2D(512, (3, 3), name='conv9',padding='same', kernel_initializer='he_uniform')(act8)
        bn9 = BatchNormalization()(conv9)
        act9 = Activation('relu')(bn9)
        # drop6 = Dropout(0.4)(act9)

        conv10 = Conv2D(512, (3, 3), name='conv10',padding='same', kernel_initializer='he_uniform')(act9)
        bn10 = BatchNormalization()(conv10)
        act10 = Activation('relu')(bn10)

        pool4 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool4')(act10)

    # Fifth convolutional block
    with tf.name_scope('con5'):
        conv11 = Conv2D(512, (3, 3), name='conv11',padding='same', kernel_initializer='he_uniform')(pool4)
        bn11 = BatchNormalization()(conv11)
        act11 = Activation('relu')(bn11)
        # drop7 = Dropout(0.4)(act11)

        conv12 = Conv2D(512, (3, 3), name='conv12',padding='same', kernel_initializer='he_uniform')(act11)
        bn12 = BatchNormalization()(conv12)
        act12 = Activation('relu')(bn12)
        # drop8 = Dropout(0.4)(act12)

        conv13 = Conv2D(512, (3, 3), name='conv13',padding='same', kernel_initializer='he_uniform')(act12)
        bn13 = BatchNormalization()(conv13)
        act13 = Activation('relu')(bn13)

        pool5 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool5')(act13)
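    # With a 60x160 input and five 2x2 'same' poolings, the feature map
    # shrinks 60x160 -> 30x80 -> 15x40 -> 8x20 -> 4x10 -> 2x5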

    # Fully connected layers
    with tf.name_scope('dense'):
        # Flatten the final pooled feature map (pool5) before the dense layers
        x = Flatten()(pool5)
        # Dropout
        x = Dropout(0.5)(x)

        x1 = Dense(4096)(x)
        bnx1 = BatchNormalization()(x1)
        actx1 = Activation('relu')(bnx1)
        drop9 = Dropout(0.4)(actx1)

        x2 = Dense(4096)(drop9)
        bnx2 = BatchNormalization()(x2)
        x = Activation('relu')(bnx2)

        # Four dense heads with 10-way softmax, one per captcha character
        x = [Dense(10, activation='softmax', name='func%d'%(i+1))(x) for i in range(4)]

    # Output layer
    with tf.name_scope('outputs'):
        # Concatenate the four character outputs into a single vector
        outs = Concatenate()(x)
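    # The concatenated output is 4 * 10 = 40-dimensional, matching the block
    # layout produced by text2vec, so vec2text can decode predictions directly.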

    # Define the model's inputs and outputs
    model = Model(inputs=inputs, outputs=outs)
    model.compile(optimizer=OPT, loss=LOSS, metrics=['accuracy'])
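    # Since each 10-way block is an independent softmax and each target block
    # is one-hot, categorical cross-entropy over the concatenated 40-vector
    # decomposes into the sum of four per-character cross-entropies.
    # A minimal sketch of an alternative (not what this script does): keep the
    # four heads as separate outputs so Keras reports per-character loss and
    # accuracy; this would require splitting Y into four (N, 10) arrays:
    #   model = Model(inputs=inputs, outputs=x)   # x is the list of 4 heads
    #   model.compile(optimizer=OPT, loss='categorical_crossentropy',
    #                 metrics=['accuracy'])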

    model.summary()
    # Optionally load saved weights to resume training
    # model.load_weights('./model/train_demo/captcha_categorical_crossentropy_bs_100_epochs_200.h5')
    # callbacks = [ModelCheckpoint('./model/cnn_best_vgg.h5', save_best_only=True)]

    # Train the model; the per-epoch training record is kept in `history`
    history = model.fit(X_train, Y_train,
                        batch_size=BATCH_SIZE,
                        epochs=EPOCHS, verbose=2,
                        validation_data=(X_test, Y_test)
                        )

    # Sample prediction
    print(vec2text(Y_test[22]))
    yy = model.predict(X_test[22].reshape(1, CAPTCHA_HEIGHT, CAPTCHA_WIDTH, 1))
    print(vec2text(yy))

    # Save the model
    os.makedirs(MODEL_DIR, exist_ok=True)
    model.save(MODEL_FILE)
    print(MODEL_FILE)

    # Save the training history
    os.makedirs(HISTORY_DIR, exist_ok=True)
    with open(HISTORY_FILE, 'wb') as f:
        pickle.dump(history.history, f)
    print(HISTORY_FILE)
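    # A minimal sketch (not part of the original run) of reloading the pickled
    # history and plotting the loss curves with the matplotlib import above:
    #   with open(HISTORY_FILE, 'rb') as f:
    #       h = pickle.load(f)
    #   plt.plot(h['loss'], label='train loss')
    #   plt.plot(h['val_loss'], label='val loss')
    #   plt.legend()
    #   plt.show()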


